{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Integrate BrainPy models into Flax (Example 1)\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/integrate_bp_lif_into_flax.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/integrate_bp_lif_into_flax.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we use ``brainpy.neurons.LIF`` as a recurrent cell in the Flax computation. ``brainpy.neurons.LIF`` only has recurrent variables, does not have trainable parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "import ml_collections\n",
    "import numpy as np\n",
    "import optax\n",
    "import tensorflow_datasets as tfds\n",
    "from flax import linen as nn\n",
    "from flax.metrics import tensorboard\n",
    "from flax.training import train_state\n",
    "\n",
    "import brainpy as bp\n",
    "import brainpy.math as bm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "bm.set(mode=bm.training_mode, dt=1.)\n",
    "\n",
    "bp.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "num_time = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "pars = dict(tau=10, V_reset=0, V_rest=0, V_th=0.1, keep_size=True, input_var=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "# LIF neurons can be viewed as a recurrent cell without trainable parameters\n",
    "cell1 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF((28, 28, 32), **pars))\n",
    "cell2 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF((14, 14, 64), **pars))\n",
    "cell3 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF(256, **pars))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "class CNN(nn.Module):\n",
    "  @nn.compact\n",
    "  def __call__(self, x):\n",
    "    x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n",
    "    x = nn.RNN(cell1, cell1.model.varshape)(x)  # Use RNN to unfold the recurrent LIF\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n",
    "    x = nn.RNN(cell2, cell2.model.varshape)(x)  # Use RNN to unfold the recurrent LIF\n",
    "    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n",
    "    x = x.reshape((x.shape[0], x.shape[1], -1))\n",
    "    x = nn.Dense(features=256)(x)\n",
    "    x = nn.RNN(cell3, cell3.model.varshape)(x)  # Use RNN to unfold the recurrent LIF\n",
    "    x = nn.Dense(features=10)(x)\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def apply_model(state, images, labels):\n",
    "  \"\"\"Computes gradients, loss and accuracy for a single batch.\"\"\"\n",
    "  images = jnp.expand_dims(images, axis=1)\n",
    "  images = jnp.tile(images, (1, num_time, 1, 1, 1))\n",
    "\n",
    "  def loss_fn(params):\n",
    "    logits = state.apply_fn({'params': params}, images)\n",
    "    logits = bm.max(logits, axis=1).value\n",
    "    one_hot = jax.nn.one_hot(labels, 10)\n",
    "    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))\n",
    "    return loss, logits\n",
    "\n",
    "  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)\n",
    "  (loss, logits), grads = grad_fn(state.params)\n",
    "  accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)\n",
    "  return grads, loss, accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "@jax.jit\n",
    "def update_model(state, grads):\n",
    "  return state.apply_gradients(grads=grads)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "def train_epoch(state, train_ds, batch_size, rng):\n",
    "  \"\"\"Train for a single epoch.\"\"\"\n",
    "  train_ds_size = len(train_ds['image'])\n",
    "  steps_per_epoch = train_ds_size // batch_size\n",
    "\n",
    "  perms = jax.random.permutation(rng, len(train_ds['image']))\n",
    "  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch\n",
    "  perms = perms.reshape((steps_per_epoch, batch_size))\n",
    "\n",
    "  epoch_loss = []\n",
    "  epoch_accuracy = []\n",
    "\n",
    "  for perm in perms:\n",
    "    batch_images = train_ds['image'][perm, ...]\n",
    "    batch_labels = train_ds['label'][perm, ...]\n",
    "    grads, loss, accuracy = apply_model(state, batch_images, batch_labels)\n",
    "    state = update_model(state, grads)\n",
    "    epoch_loss.append(loss)\n",
    "    epoch_accuracy.append(accuracy)\n",
    "  train_loss = np.mean(epoch_loss)\n",
    "  train_accuracy = np.mean(epoch_accuracy)\n",
    "  return state, train_loss, train_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "def get_datasets():\n",
    "  \"\"\"Load MNIST train and test datasets into memory.\"\"\"\n",
    "  ds_builder = tfds.builder('mnist')\n",
    "  ds_builder.download_and_prepare()\n",
    "  train_ds = tfds.as_numpy(ds_builder.as_dataset(split='train', batch_size=-1))\n",
    "  test_ds = tfds.as_numpy(ds_builder.as_dataset(split='test', batch_size=-1))\n",
    "  train_ds['image'] = jnp.asarray(train_ds['image']) / 255.\n",
    "  test_ds['image'] = jnp.asarray(test_ds['image']) / 255.\n",
    "  return train_ds, test_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "def create_train_state(rng, config):\n",
    "  \"\"\"Creates initial `TrainState`.\"\"\"\n",
    "  cnn = CNN()\n",
    "  params = cnn.init(rng, jnp.ones([1, num_time, 28, 28, 1]))['params']\n",
    "  tx = optax.sgd(config.learning_rate, config.momentum)\n",
    "  return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "def train_and_evaluate(config: ml_collections.ConfigDict,\n",
    "                       workdir: str) -> train_state.TrainState:\n",
    "  \"\"\"Execute model training and evaluation loop.\n",
    "\n",
    "  Args:\n",
    "    config: Hyperparameter configuration for training and evaluation.\n",
    "    workdir: Directory where the tensorboard summaries are written to.\n",
    "\n",
    "  Returns:\n",
    "    The train state (which includes the `.params`).\n",
    "  \"\"\"\n",
    "  train_ds, test_ds = get_datasets()\n",
    "  rng = jax.random.PRNGKey(0)\n",
    "\n",
    "  summary_writer = tensorboard.SummaryWriter(workdir)\n",
    "  summary_writer.hparams(dict(config))\n",
    "\n",
    "  rng, init_rng = jax.random.split(rng)\n",
    "  state = create_train_state(init_rng, config)\n",
    "\n",
    "  for epoch in range(1, config.num_epochs + 1):\n",
    "    rng, input_rng = jax.random.split(rng)\n",
    "    state, train_loss, train_accuracy = train_epoch(state,\n",
    "                                                    train_ds,\n",
    "                                                    config.batch_size,\n",
    "                                                    input_rng)\n",
    "    test_losses, test_accs = [], []\n",
    "    for i in range(0, test_ds['image'].shape[0], config.batch_size):\n",
    "      _, test_loss, test_accuracy = apply_model(state,\n",
    "                                              test_ds['image'][i: i + config.batch_size],\n",
    "                                              test_ds['label'][i: i + config.batch_size])\n",
    "      test_losses.append(test_loss)\n",
    "      test_accs.append(test_accuracy)\n",
    "    test_loss = np.mean(test_loss)\n",
    "    test_accuracy = np.mean(test_accs)\n",
    "\n",
    "    print(\n",
    "      'epoch:% 3d, train_loss: %.4f, train_accuracy: %.2f, test_loss: %.4f, test_accuracy: %.2f'\n",
    "      % (epoch, train_loss, train_accuracy * 100, test_loss, test_accuracy * 100)\n",
    "    )\n",
    "\n",
    "    summary_writer.scalar('train_loss', train_loss, epoch)\n",
    "    summary_writer.scalar('train_accuracy', train_accuracy, epoch)\n",
    "    summary_writer.scalar('test_loss', test_loss, epoch)\n",
    "    summary_writer.scalar('test_accuracy', test_accuracy, epoch)\n",
    "\n",
    "  summary_writer.flush()\n",
    "  return state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "config = ml_collections.ConfigDict()\n",
    "\n",
    "config.learning_rate = 0.1\n",
    "config.momentum = 0.9\n",
    "config.batch_size = 128\n",
    "config.num_epochs = 10\n",
    "\n",
    "train_and_evaluate(config, './ckpt')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
